// STL
#include <cmath>                        // std::abs(), std::sqrt()
#include <complex>                      // std::complex<>, ::real(), ::imag()
#include <cstddef>                      // std::size_t
#include <limits>                       // std::numeric_limits<>
#include <stdexcept>                    // std::runtime_error()
#include <string>                       // std::to_string()
#include <tuple>                        // std::tie()
#include <vector>                       // std::vector<>

// jScience
#include "jScience/linalg.hpp"          // Vector<>, Matrix<>, outer_product(), inverse(), eigen()
#include "jScience/stl/MultiVector.hpp" // MultiVector<>
#include "jutility.hpp"                 // sort_indices_desc()

// stats++
#include "statsxx/statistics.hpp"       // statistics::sum()


//
// DESC: Perform C-class LDA.
//
// -----
//
// INPUT:
//          int                                      nC : number of classes
//          -----
//          int                                      nd : number of dimensions
//          =====
//          std::vector<std::vector<Vector<double>>> x  : sets of Vector<>s, separated by classes
//
// OUTPUT:
//          std::vector<Vector<double>>              w  : optimal projection Vector<>s
//
// =====
//
// NOTE: See:
//
//     http://www.sci.utah.edu/~shireen/pdfs/tutorials/Elhabian_LDA09.pdf
//
inline std::vector<Vector<double>> LDA(
                                       const int                                       nC,
                                       // -----
                                       const int                                       nd,
                                       // =====
                                       const std::vector<std::vector<Vector<double>>> &x
                                       )
{
    std::vector<Vector<double>> w;

    //=========================================================
    // ERROR CHECKS + INFER INFORMATION ABOUT DATASET
    //=========================================================

    if( nC != x.size() )
    {
        throw std::runtime_error( "error in LDA(): ( nC != x.size() )" );
    }

    std::vector<int> N(nC);

    for( int i = 0; i < nC; ++i )
    {
        N[i] = x[i].size();

        if( N[i] == 0 )
        {
            throw std::runtime_error( ("error in LDA(): no data for class " + std::to_string(i)) );
        }

        for( auto &x_j : x[i] )
        {
            if( x_j.size() != nd )
            {
                throw std::runtime_error( "error in LDA(): ( x_j.size() != nd )" );
            }
        }
    }

    //=========================================================
    // CALCULATE MEANS
    //=========================================================

    // Vector<> MEANS
    std::vector<Vector<double>> mu(nC, Vector<double>(nd, 0.));

    for( int i = 0; i < nC; ++i )
    {
        for( auto &x_j : x[i] )
        {
            mu[i] += x_j;
        }

        mu[i] /= x[i].size();
    }

    // CLASSES MEAN
    Vector<double> mu_C(nd, 0.);

    for( int i = 0; i < nC; ++i )
    {
        mu_C += ( N[i]*mu[i] );
    }

    mu_C /= statistics::sum(N);

    //=========================================================
    // SCATTER MATRICES
    //=========================================================

    // WITHIN-CLASS SCATTER
    Matrix<double> S_W(nd,nd, 0.);

    for( int i = 0; i < nC; ++i )
    {
        Matrix<double> S_i(nd,nd, 0.);

        for( auto &x_j : x[i] )
        {
            Vector<double> x_mu = x_j -  mu[i];

            S_i += outer_product(x_mu, x_mu);
        }

        S_W += S_i;
    }

    int            INFO0;   // DGETRF exit code
    int            INFO;    // DGETRI exit code
    Matrix<double> S_W_inv; // Matrix<> inverse
    std::tie(
             INFO0,
             INFO,
             S_W_inv
             ) = inverse(
                         S_W
                         );

    if( (INFO0 != 0) || (INFO != 0) )
    {
        throw std::runtime_error( "error in LDA(): ( (INFO0 != 0) || (INFO != 0) ) on return from inverse()" );
    }

    // BETWEEN-CLASS SCATTER
    Matrix<double> S_B(nd,nd, 0.);

    for( int i = 0; i < nC; ++i )
    {
        Vector<double> mu_mu_C = mu[i] - mu_C;

        S_B += ( static_cast<double>(N[i])*outer_product(mu_mu_C, mu_mu_C) );
    }

    //=========================================================
    // PROJECTION MATRIX
    //=========================================================

    int                               _INFO; // DGEEV exit code
    std::vector<std::complex<double>> _W;    // eigenvalues
    MultiVector<std::complex<double>> VL;    // left eigenvectors
    MultiVector<std::complex<double>> VR;    // right eigenvectors
    std::tie(
             _INFO,
             _W,
             VL,
             VR
             ) = eigen(
                       (S_W_inv*S_B)
                       );

    if( _INFO != 0 )
    {
        throw std::runtime_error( "error in LDA(): ( _INFO != 0 ) on return from eigen()" );
    }

    // SORT BY (ASSUMED REAL) EIGENVALUES
    std::vector<double> W_r;

    for( auto _W_i : _W )
    {
        if( std::imag(_W_i) != 0. )
        {
            throw std::runtime_error( "error in LDA(): imaginary eigenvalue" );
        }

        W_r.push_back(std::real(_W_i));
    }

    for( auto &idx : sort_indices_desc(W_r) )
    {
        Vector<double> w_idx(nd);

        // NOTE: Right eigenvectors are stored in the columns of VR.
        //
        for( int j = 0; j < nd; ++j )
        {
            w_idx(j) = std::real(VR(j,idx));
        }

        w.push_back(w_idx);
    }

    //=========================================================
    // RETURN
    //=========================================================

    return w;
}
